package manager

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"gitee.com/xfrm/middleware/internal/breaker"
	"gitee.com/xfrm/middleware/xlog"
	"gitee.com/xfrm/middleware/xsql/builder"
	"gitee.com/xfrm/middleware/xsql/scanner"
	"gitee.com/xfrm/middleware/xtime"
	"gitee.com/xfrm/middleware/xtrace"
	"github.com/VividCortex/mysqlerr"
	"github.com/go-sql-driver/mysql"
)

// Tx is a wrapper around sql.Tx. Its methods delegate to the wrapped
// transaction and add breaker checks, trace spans and request metrics,
// all labelled with the cluster name.
type Tx struct {
	// tx is the underlying database/sql transaction that every statement,
	// commit and rollback is delegated to.
	tx *sql.Tx
	// cluster is the label passed to the breaker, metrics and trace tags; it
	// is also the fallback table label when no table name can be extracted
	// from a query.
	cluster string
}

// ExecContext runs a statement that returns no rows (insert/update/delete
// and so on) on the wrapped transaction.
//
// Before executing, the breaker is consulted for the (cluster, table) pair;
// when it denies entry the statement is not executed, an error is logged and
// a generic breaker error is returned. Otherwise the call is wrapped in a
// trace span, its duration is reported, and the resulting error (nil or not)
// is handed to the breaker and to the error counter.
//
// The returned sql.Result and error are exactly those of sql.Tx.ExecContext.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	const op = "xsql.Tx.ExecContext"
	tableName := tx.fetchTableName(query)

	// Short-circuit while the breaker refuses traffic for this cluster/table.
	if allowed := breaker.Entry(tx.cluster, tableName); !allowed {
		xlog.Errorf(ctx, "%s trigger tidb breaker, because too many timeout sqls, cluster: %s, table: %s", op, tx.cluster, tableName)
		return nil, errors.New("sql cause breaker, because too many timeout")
	}

	// Open the span first: the derived ctx is what the query rewrite and the
	// driver call below receive.
	span, ctx := xtrace.StartSpanFromContext(ctx, op)
	defer span.Finish()
	query = injectSQLTraceIDLineComment(ctx, query)
	statement := fmt.Sprintf("%s %v", query, args)
	setDBSpanTags(span, tx.cluster, tableName, statement)

	timer := xtime.NewTimeStat()
	result, execErr := tx.tx.ExecContext(ctx, query, args...)
	statMetricReqDur(ctx, tx.cluster, tableName, "exec", timer.Millisecond())

	// Report the outcome to the breaker, then to the error counter.
	breaker.StatBreaker(tx.cluster, tableName, execErr)
	statMetricReqErrTotal(ctx, tx.cluster, tableName, "exec", execErr)
	return result, execErr
}

// QueryContext executes a query that returns rows, typically a SELECT, on
// the wrapped transaction.
//
// The breaker is checked first for the (cluster, table) pair; if it denies
// entry the query is not sent, an error is logged and a generic breaker error
// is returned with nil rows. Otherwise the call runs inside a trace span, its
// duration is reported, and the resulting error is passed to the breaker and
// to the error counter.
//
// The caller owns the returned *sql.Rows and is responsible for closing it.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	const op = "xsql.Tx.QueryContext"
	tableName := tx.fetchTableName(query)

	// Short-circuit while the breaker refuses traffic for this cluster/table.
	if allowed := breaker.Entry(tx.cluster, tableName); !allowed {
		xlog.Errorf(ctx, "%s trigger tidb breaker, because too many timeout sqls, cluster: %s, table: %s", op, tx.cluster, tableName)
		return nil, errors.New("sql cause breaker, because too many timeout")
	}

	// The span-derived ctx is used for the query rewrite and the driver call.
	span, ctx := xtrace.StartSpanFromContext(ctx, op)
	defer span.Finish()
	query = injectSQLTraceIDLineComment(ctx, query)
	statement := fmt.Sprintf("%s %v", query, args)
	setDBSpanTags(span, tx.cluster, tableName, statement)

	timer := xtime.NewTimeStat()
	rows, queryErr := tx.tx.QueryContext(ctx, query, args...)
	statMetricReqDur(ctx, tx.cluster, tableName, "query", timer.Millisecond())

	// Report the outcome to the breaker, then to the error counter.
	breaker.StatBreaker(tx.cluster, tableName, queryErr)
	statMetricReqErrTotal(ctx, tx.cluster, tableName, "query", queryErr)
	return rows, queryErr
}

// QueryRowContext executes a query that is expected to return at most one
// row on the wrapped transaction.
//
// The breaker is checked first for the (cluster, table) pair. When it denies
// entry the query is not sent, an error is logged, and nil is returned.
// NOTE(review): callers must nil-check the result; calling Scan on a nil
// *sql.Row panics.
//
// Unlike ExecContext and QueryContext, only the duration metric is reported
// here: the row's error is not surfaced until Scan, so nothing is passed to
// the breaker or the error counter.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	const op = "xsql.Tx.QueryRowContext"
	tableName := tx.fetchTableName(query)

	// Short-circuit while the breaker refuses traffic for this cluster/table.
	if allowed := breaker.Entry(tx.cluster, tableName); !allowed {
		xlog.Errorf(ctx, "%s trigger tidb breaker, because too many timeout sqls, cluster: %s, table: %s", op, tx.cluster, tableName)
		return nil
	}

	// The span-derived ctx is used for the query rewrite and the driver call.
	span, ctx := xtrace.StartSpanFromContext(ctx, op)
	defer span.Finish()
	query = injectSQLTraceIDLineComment(ctx, query)
	statement := fmt.Sprintf("%s %v", query, args)
	setDBSpanTags(span, tx.cluster, tableName, statement)

	timer := xtime.NewTimeStat()
	row := tx.tx.QueryRowContext(ctx, query, args...)
	statMetricReqDur(ctx, tx.cluster, tableName, "query row", timer.Millisecond())
	return row
}

// SelectOne builds a SELECT over table from the where map, runs it inside
// the transaction and scans the result into item.
//
// When where carries no "_limit" key, a limit of []uint{0, 1} is added. The
// default is written to the map returned by copyWhere, not to where itself.
//
// It returns an error when tx is nil, when the statement cannot be built,
// when the query fails, or when scanning fails.
func (tx *Tx) SelectOne(ctx context.Context, table string, where map[string]interface{}, item interface{}) error {
	if tx == nil {
		return errors.New("manager.XDB object couldn't be nil")
	}

	filter := copyWhere(where)
	if _, hasLimit := filter["_limit"]; !hasLimit {
		filter["_limit"] = []uint{0, 1}
	}

	stmt, params, buildErr := builder.BuildSelectWithContext(ctx, table, filter, nil)
	if buildErr != nil {
		return buildErr
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, params)

	rows, queryErr := tx.QueryContext(ctx, stmt, params...)
	if queryErr != nil {
		return queryErr
	}
	if rows == nil {
		// No rows handle and no error: nothing to scan.
		return nil
	}
	defer rows.Close()

	return scanner.Scan(rows, item)
}

// Select builds a SELECT over table from the where map, runs it inside the
// transaction and scans the result set into items.
//
// It returns an error when tx is nil, when the statement cannot be built,
// when the query fails, or when scanning fails.
func (tx *Tx) Select(ctx context.Context, table string, where map[string]interface{}, items interface{}) error {
	if tx == nil {
		return errors.New("manager.XDB object couldn't be nil")
	}

	stmt, params, buildErr := builder.BuildSelectWithContext(ctx, table, where, nil)
	if buildErr != nil {
		return buildErr
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, params)

	rows, queryErr := tx.QueryContext(ctx, stmt, params...)
	if queryErr != nil {
		return queryErr
	}
	if rows == nil {
		// No rows handle and no error: nothing to scan.
		return nil
	}
	defer rows.Close()

	return scanner.Scan(rows, items)
}

// Insert builds an INSERT for the given rows and executes it inside the
// transaction.
//
// On success it returns whatever the driver reports from
// sql.Result.LastInsertId. It returns (0, err) when tx is nil, when the
// statement cannot be built, or when execution fails.
func (tx *Tx) Insert(ctx context.Context, table string, data []map[string]interface{}) (int64, error) {
	if tx == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}

	stmt, params, buildErr := builder.BuildInsert(table, data)
	if buildErr != nil {
		return 0, buildErr
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, params)

	res, execErr := tx.ExecContext(ctx, stmt, params...)
	switch {
	case execErr != nil:
		return 0, execErr
	case res == nil:
		// No result and no error: report zero rather than dereferencing nil.
		return 0, nil
	}
	return res.LastInsertId()
}

// Update builds an UPDATE that applies data to the rows of table matching
// where, and executes it inside the transaction.
//
// On success it returns sql.Result.RowsAffected. It returns (0, err) when tx
// is nil, when the statement cannot be built, or when execution fails.
func (tx *Tx) Update(ctx context.Context, table string, where, data map[string]interface{}) (int64, error) {
	if tx == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}

	stmt, params, buildErr := builder.BuildUpdate(table, where, data)
	if buildErr != nil {
		return 0, buildErr
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, params)

	res, execErr := tx.ExecContext(ctx, stmt, params...)
	if execErr != nil {
		return 0, execErr
	}
	return res.RowsAffected()
}

// Upsert tries to insert a row and, if the insert fails with a MySQL
// duplicate-entry error, falls back to updating the rows matching where.
//
// The inserted row is the union of data and insertSet, with insertSet
// winning on key collisions. The fallback update applies data only.
//
// The meaning of the returned int64 depends on the path taken: the value
// returned by Insert (last insert ID) when the insert succeeds, or the value
// returned by Update (rows affected) when the fallback runs. Any other insert
// error is returned unchanged.
func (tx *Tx) Upsert(ctx context.Context, table string, where, data, insertSet map[string]interface{}) (int64, error) {
	if nil == tx {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}

	// Merge into a fresh map so neither caller-supplied map is mutated.
	insertData := make(map[string]interface{}, len(data)+len(insertSet))
	for k, v := range data {
		insertData[k] = v
	}
	for k, v := range insertSet {
		insertData[k] = v
	}

	lastID, err := tx.Insert(ctx, table, []map[string]interface{}{insertData})
	if err == nil {
		return lastID, nil
	}

	// errors.As also matches a MySQLError that has been wrapped somewhere
	// along the call chain, which a plain type assertion would miss.
	var mysqlErr *mysql.MySQLError
	if errors.As(err, &mysqlErr) && int(mysqlErr.Number) == mysqlerr.ER_DUP_ENTRY {
		return tx.Update(ctx, table, where, data)
	}
	return lastID, err
}

// Delete builds a DELETE for the rows of table matching where and executes
// it inside the transaction.
//
// On success it returns sql.Result.RowsAffected. It returns (0, err) when tx
// is nil, when the statement cannot be built, or when execution fails.
func (tx *Tx) Delete(ctx context.Context, table string, where map[string]interface{}) (int64, error) {
	if tx == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}

	stmt, params, buildErr := builder.BuildDelete(table, where)
	if buildErr != nil {
		return 0, buildErr
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, params)

	res, execErr := tx.ExecContext(ctx, stmt, params...)
	if execErr != nil {
		return 0, execErr
	}
	return res.RowsAffected()
}

// SelectCount builds a COUNT(*) style SELECT over table from the where map,
// runs it inside the transaction and returns the scanned count.
//
// It returns (0, err) when tx is nil, when the statement cannot be built,
// when the query fails, when scanning a row fails, or when row iteration
// ends with an error. If the query yields no rows the count is 0.
func (tx *Tx) SelectCount(ctx context.Context, table string, where map[string]interface{}) (count int, err error) {
	if nil == tx {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}
	cond, vals, err := builder.BuildSelect(table, where, []string{builder.AggregateCount("*").Symble()})
	if nil != err {
		return 0, err
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", cond, vals)
	rows, err := tx.QueryContext(ctx, cond, vals...)
	if nil != err {
		return 0, err
	}
	// Always release the rows handle, including on the early return after a
	// failed Scan; otherwise it stays open.
	defer rows.Close()

	for rows.Next() {
		if err = rows.Scan(&count); err != nil {
			return 0, err
		}
	}
	// Next returning false can mean "done" or "failed"; Err tells them apart.
	if err = rows.Err(); err != nil {
		return 0, err
	}
	return count, nil
}

// Commit commits the wrapped sql.Tx inside a trace span and reports the
// call's duration and resulting error as "commit" metrics.
//
// There is no table for this operation, so the cluster name is passed in the
// table position of the span tags and metrics. The returned error is exactly
// that of sql.Tx.Commit.
func (tx *Tx) Commit(ctx context.Context) error {
	span, ctx := xtrace.StartSpanFromContext(ctx, "xsql.Tx.Commit")
	defer span.Finish()

	cluster := tx.cluster
	setDBSpanTags(span, cluster, cluster, "")

	timer := xtime.NewTimeStat()
	commitErr := tx.tx.Commit()
	statMetricReqDur(ctx, cluster, cluster, "commit", timer.Millisecond())
	statMetricReqErrTotal(ctx, cluster, cluster, "commit", commitErr)
	return commitErr
}

// Rollback rolls back the wrapped sql.Tx inside a trace span and reports the
// call's duration and resulting error as "rollback" metrics.
//
// There is no table for this operation, so the cluster name is passed in the
// table position of the span tags and metrics. The returned error is exactly
// that of sql.Tx.Rollback.
func (tx *Tx) Rollback(ctx context.Context) error {
	// The span is named after this operation so rollbacks are not recorded
	// as commits in traces.
	span, ctx := xtrace.StartSpanFromContext(ctx, "xsql.Tx.Rollback")
	defer span.Finish()
	setDBSpanTags(span, tx.cluster, tx.cluster, "")

	st := xtime.NewTimeStat()
	err := tx.tx.Rollback()
	statMetricReqDur(ctx, tx.cluster, tx.cluster, "rollback", st.Millisecond())
	statMetricReqErrTotal(ctx, tx.cluster, tx.cluster, "rollback", err)
	return err
}

// fetchTableName returns the label used as the "table" for breaker, trace
// and metric calls. It prefers the name extractSQLTableName derives from
// query; when that is empty it falls back to the cluster name, or to the
// empty string if tx is nil.
func (tx *Tx) fetchTableName(query string) (table string) {
	if name := extractSQLTableName(query); name != "" {
		return name
	}
	if tx == nil {
		return ""
	}
	return tx.cluster
}
